{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "siBWl3-lEd8x",
        "outputId": "77a7acbf-6f53-4c5c-e713-9968f8b561b8"
      },
      "outputs": [],
      "source": [
        "!pip install tensorflow tensorflow_addons tensorflow_datasets tensorflow_hub numpy matplotlib seaborn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HHEXVxQw_g5B"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds\n",
        "import tensorflow_hub as hub\n",
        "import tensorflow_addons as tfa"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dgKI5S4DC31Z"
      },
      "outputs": [],
      "source": [
        "# load the whole dataset, for data info\n",
        "all_ds   = tfds.load(\"eurosat\", with_info=True)\n",
        "# load training, testing & validation sets, splitting by 60%, 20% and 20% respectively\n",
        "train_ds = tfds.load(\"eurosat\", split=\"train[:60%]\")\n",
        "test_ds  = tfds.load(\"eurosat\", split=\"train[60%:80%]\")\n",
        "valid_ds = tfds.load(\"eurosat\", split=\"train[80%:]\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AcBsiqtBcwwF"
      },
      "outputs": [],
      "source": [
        "# the class names\n",
        "class_names = all_ds[1].features[\"label\"].names\n",
        "# total number of classes (10)\n",
        "num_classes = len(class_names)\n",
        "num_examples = all_ds[1].splits[\"train\"].num_examples"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 621
        },
        "id": "MjrKwojyo-kE",
        "outputId": "c8ef6c03-9035-4785-d236-c9c104412e34"
      },
      "outputs": [],
      "source": [
        "# make a plot for number of samples on each class\n",
        "fig, ax = plt.subplots(1, 1, figsize=(14,10))\n",
        "labels, counts = np.unique(np.fromiter(all_ds[0][\"train\"].map(lambda x: x[\"label\"]), np.int32), \n",
        "                       return_counts=True)\n",
        "\n",
        "plt.ylabel('Counts')\n",
        "plt.xlabel('Labels')\n",
        "sns.barplot(x = [class_names[l] for l in labels], y = counts, ax=ax) \n",
        "for i, x_ in enumerate(labels):\n",
        "  ax.text(x_-0.2, counts[i]+5, counts[i])\n",
        "# set the title\n",
        "ax.set_title(\"Bar Plot showing Number of Samples on Each Class\")\n",
        "# save the image\n",
        "# plt.savefig(\"class_samples.png\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g60GgX9hEPBc"
      },
      "outputs": [],
      "source": [
        "def prepare_for_training(ds, cache=True, batch_size=64, shuffle_buffer_size=1000):\n",
        "  if cache:\n",
        "    if isinstance(cache, str):\n",
        "      ds = ds.cache(cache)\n",
        "    else:\n",
        "      ds = ds.cache()\n",
        "  ds = ds.map(lambda d: (d[\"image\"], tf.one_hot(d[\"label\"], num_classes)))\n",
        "  # shuffle the dataset\n",
        "  ds = ds.shuffle(buffer_size=shuffle_buffer_size)\n",
        "  # Repeat forever\n",
        "  ds = ds.repeat()\n",
        "  # split to batches\n",
        "  ds = ds.batch(batch_size)\n",
        "  # `prefetch` lets the dataset fetch batches in the background while the model\n",
        "  # is training.\n",
        "  ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n",
        "  return ds"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TP_CGr3kNw0c"
      },
      "outputs": [],
      "source": [
        "batch_size = 64\n",
        "\n",
        "# preprocess training & validation sets\n",
        "train_ds = prepare_for_training(train_ds, batch_size=batch_size)\n",
        "valid_ds = prepare_for_training(valid_ds, batch_size=batch_size)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vP-ioWj9e37z",
        "outputId": "2b455894-72e9-4771-905c-5b58701d3a98"
      },
      "outputs": [],
      "source": [
        "# validating shapes\n",
        "for el in valid_ds.take(1):\n",
        "  print(el[0].shape, el[1].shape)\n",
        "for el in train_ds.take(1):\n",
        "  print(el[0].shape, el[1].shape)  "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cIW7hbHhOVq0"
      },
      "outputs": [],
      "source": [
        "# take the first batch of the training set\n",
        "batch = next(iter(train_ds))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 473
        },
        "id": "TNRbCVp6Na1A",
        "outputId": "412e7412-86f9-467d-d88d-eaa9fd12600a"
      },
      "outputs": [],
      "source": [
        "def show_batch(batch):\n",
        "  plt.figure(figsize=(16, 16))\n",
        "  for n in range(min(32, batch_size)):\n",
        "      ax = plt.subplot(batch_size//8, 8, n + 1)\n",
        "      # show the image\n",
        "      plt.imshow(batch[0][n])\n",
        "      # and put the corresponding label as title upper to the image\n",
        "      plt.title(class_names[tf.argmax(batch[1][n].numpy())])\n",
        "      plt.axis('off')\n",
        "      plt.savefig(\"sample-images.png\")\n",
        "\n",
        "# showing a batch of images along with labels\n",
        "show_batch(batch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JMVzjuqmoOcB"
      },
      "outputs": [],
      "source": [
        "model_url = \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_l/feature_vector/2\"\n",
        "\n",
        "# download & load the layer as a feature vector\n",
        "keras_layer = hub.KerasLayer(model_url, output_shape=[1280], trainable=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uhKLvFpkfiCr"
      },
      "outputs": [],
      "source": [
        "m = tf.keras.Sequential([\n",
        "  keras_layer,\n",
        "  tf.keras.layers.Dense(num_classes, activation=\"softmax\")\n",
        "])\n",
        "# build the model with input image shape as (64, 64, 3)\n",
        "m.build([None, 64, 64, 3])\n",
        "m.compile(\n",
        "    loss=\"categorical_crossentropy\", \n",
        "    optimizer=\"adam\", \n",
        "    metrics=[\"accuracy\", tfa.metrics.F1Score(num_classes)]\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-QMzJ4-fhD-B",
        "outputId": "e9101d76-e18c-42e6-b27c-08c1e037f81a"
      },
      "outputs": [],
      "source": [
        "m.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I0vYaHjPhUDF"
      },
      "outputs": [],
      "source": [
        "model_name = \"satellite-classification\"\n",
        "model_path = os.path.join(\"results\", model_name + \".h5\")\n",
        "model_checkpoint = tf.keras.callbacks.ModelCheckpoint(model_path, save_best_only=True, verbose=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IP93lr9DdH9J"
      },
      "outputs": [],
      "source": [
        "n_training_steps   = int(num_examples * 0.6) // batch_size\n",
        "n_validation_steps = int(num_examples * 0.2) // batch_size"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mUp9ocf-hnlC",
        "outputId": "7f372750-7fa9-4dad-8286-bb324ff56219"
      },
      "outputs": [],
      "source": [
        "history = m.fit(\n",
        "    train_ds, validation_data=valid_ds,\n",
        "    steps_per_epoch=n_training_steps,\n",
        "    validation_steps=n_validation_steps,\n",
        "    verbose=1, epochs=5, \n",
        "    callbacks=[model_checkpoint]\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9kjuwUEEXWQ5"
      },
      "outputs": [],
      "source": [
        "# number of testing steps\n",
        "n_testing_steps = int(all_ds[1].splits[\"train\"].num_examples * 0.2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6U_pHoLGKj-f"
      },
      "outputs": [],
      "source": [
        "m.load_weights(model_path)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "JWL3NBXfXhTQ",
        "outputId": "2786d6a0-af1f-47ae-9927-a495100989cc"
      },
      "outputs": [],
      "source": [
        "# get all testing images as NumPy array\n",
        "images = np.array([ d[\"image\"] for d in test_ds.take(n_testing_steps) ])\n",
        "print(\"images.shape:\", images.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "sgq55TrQVtr3",
        "outputId": "f4c51aaa-8229-444f-ee76-1d9385b94f3e"
      },
      "outputs": [],
      "source": [
        "# get all testing labels as NumPy array\n",
        "labels = np.array([ d[\"label\"] for d in test_ds.take(n_testing_steps) ])\n",
        "print(\"labels.shape:\", labels.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5UbOsNtmXDqR",
        "outputId": "a8a5d860-b570-4639-fcf4-36b56db97421"
      },
      "outputs": [],
      "source": [
        "# feed the images to get predictions\n",
        "predictions = m.predict(images)\n",
        "# perform argmax to get class index\n",
        "predictions = np.argmax(predictions, axis=1)\n",
        "print(\"predictions.shape:\", predictions.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GX-GkI9Gy1hS",
        "outputId": "5ecd1703-093c-4cb8-a8fa-2111e7fb670b"
      },
      "outputs": [],
      "source": [
        "from sklearn.metrics import f1_score\n",
        "\n",
        "accuracy = tf.keras.metrics.Accuracy()\n",
        "accuracy.update_state(labels, predictions)\n",
        "print(\"Accuracy:\", accuracy.result().numpy())\n",
        "print(\"F1 Score:\", f1_score(labels, predictions, average=\"macro\"))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 736
        },
        "id": "yszXAmfdVcOA",
        "outputId": "56bb4353-4c37-405d-fc18-9f17e19a8c71"
      },
      "outputs": [],
      "source": [
        "# compute the confusion matrix\n",
        "cmn = tf.math.confusion_matrix(labels, predictions).numpy()\n",
        "# normalize the matrix to be in percentages\n",
        "cmn = cmn.astype('float') / cmn.sum(axis=0)[:, np.newaxis]\n",
        "# make a plot for the confusion matrix\n",
        "fig, ax = plt.subplots(figsize=(10,10))\n",
        "sns.heatmap(cmn, annot=True, fmt='.2f', \n",
        "            xticklabels=[f\"pred_{c}\" for c in class_names], \n",
        "            yticklabels=[f\"true_{c}\" for c in class_names],\n",
        "            # cmap=\"Blues\"\n",
        "            cmap=\"rocket_r\"\n",
        "            )\n",
        "plt.ylabel('Actual')\n",
        "plt.xlabel('Predicted')\n",
        "# plot the resulting confusion matrix\n",
        "plt.savefig(\"confusion-matrix.png\")\n",
        "# plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 808
        },
        "id": "txdBO11IbF8e",
        "outputId": "5de1a7a1-1039-475a-8bfa-3e02a4e984a1"
      },
      "outputs": [],
      "source": [
        "def show_predicted_samples():\n",
        "  plt.figure(figsize=(14, 14))\n",
        "  for n in range(64):\n",
        "      ax = plt.subplot(8, 8, n + 1)\n",
        "      # show the image\n",
        "      plt.imshow(images[n])\n",
        "      # and put the corresponding label as title upper to the image\n",
        "      if predictions[n] == labels[n]:\n",
        "        # correct prediction\n",
        "        ax.set_title(class_names[predictions[n]], color=\"green\")\n",
        "      else:\n",
        "        # wrong prediction\n",
        "        ax.set_title(f\"{class_names[predictions[n]]}/T:{class_names[labels[n]]}\", color=\"red\")\n",
        "      plt.axis('off')\n",
        "      plt.savefig(\"predicted-sample-images.png\")\n",
        "\n",
        "# showing a batch of images along with predictions labels\n",
        "show_predicted_samples()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xgd0y1Ul5aQi"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "Satellite-Image-Classification-with-TensorFlow_PythonCode.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
